//----------------------------------------------------------------------------
//
// Copyright (C) Sartorius Stedim Data Analytics AB 2017 -
//
// Use, modification and distribution are subject to the Boost Software
// License, Version 1.0. (See http://www.boost.org/LICENSE_1_0.txt)
//
//----------------------------------------------------------------------------

#include "EzQ.h"
#include <cassert>
#include <map>

// Evaluates the SIMCA-Q call `func` exactly once and returns its error code from the
// enclosing function when it is anything other than SQ_E_OK.
// Wrapped in do { } while (0) so that `ReturnOnError(x);` is a single statement and is
// safe inside an unbraced if/else; `func` is parenthesized and the local gets a
// trailing underscore so it cannot clash with names used in the macro argument.
#define ReturnOnError(func) do { const SQ_ErrorCode err_ = (func); if (err_ != SQ_E_OK) return err_; } while (0)


// Owns one SQ_PreparePrediction handle for the lifetime of the object.
// The handle starts out null, is filled in by SIMCA-Q through the conversion
// operators (e.g. SQ_GetPreparePrediction(hModel, obj)), and is released with
// SQ_ClearPreparePrediction when the object is destroyed.
class CEzPreparePrediction
{
public:
   CEzPreparePrediction() = default;
   ~CEzPreparePrediction()
   {
      // Nothing to release if SIMCA-Q never handed out a handle.
      if (mhPrepPred)
         SQ_ClearPreparePrediction(&mhPrepPred);
   }

   // The handle is uniquely owned; a copy would release it twice.
   CEzPreparePrediction(const CEzPreparePrediction&) = delete;
   CEzPreparePrediction& operator=(const CEzPreparePrediction&) = delete;

   // Address of the handle, for SIMCA-Q out-parameters.
   operator SQ_PreparePrediction* () { return &mhPrepPred; }
   operator const SQ_PreparePrediction* () const { return &mhPrepPred; }
   // The handle itself, for SIMCA-Q calls that take it by value.
   operator SQ_PreparePrediction& () { return mhPrepPred; }

protected:
   SQ_PreparePrediction mhPrepPred = nullptr;
};

//////////////////////////////////////////////////////////////////////////
// Constructor
// Defaulted: members take the initializers declared in EzQ.h (not visible here;
// presumably null handles and -1 cached values, as CloseProject() resets them — TODO confirm).
CEzQ::CEzQ() = default;

//////////////////////////////////////////////////////////////////////////
// Destructor
// Releases whatever this instance still holds (prediction and project) by
// delegating to CloseProject().
CEzQ::~CEzQ()
{
   // CloseProject() has no failure path worth reporting from a destructor.
   static_cast<void>(CloseProject());
}

//////////////////////////////////////////////////////////////////////////
// Opens the SIMCA-Q project stored in szUSPFile, using szPassword if the
// project is protected. A project that is already open is closed first, so
// this object holds at most one project at a time.
// Returns the error code reported by SQ_OpenProject.
SQ_ErrorCode CEzQ::OpenProject(const char* szUSPFile, const char* szPassword)
{
   CloseProject();   // drop the previous project and any prediction made with it

   const SQ_ErrorCode eOpenResult = SQ_OpenProject(szUSPFile, szPassword, &m_hProject);
   return eOpenResult;
}

//////////////////////////////////////////////////////////////////////////
// Set the model to use.
//
// Selects the model with the given 1-based index in the open project and caches
// what the prediction getters need: the number of predictive components, the
// component vector, the default probability level and the DModX / Hotelling T2
// critical limits.
//
// Any prediction made with the previously selected model is released, and the
// new model handle is only stored in m_hModel once every step has succeeded.
// A failed call therefore leaves no model selected rather than a half-configured one.
//
// Returns SQ_E_INVALIDPROJECTHANDLE if no project is open, SQ_E_INVALIDINDEX if
// iModelIndex is outside 1..number of models, SQ_E_MODELNOTFITTED if the model
// is not fitted, otherwise the first failing SIMCA-Q error code or SQ_E_OK.
SQ_ErrorCode CEzQ::SetModel(int iModelIndex)
{
   assert(m_hProject);
   if (!m_hProject)
      return SQ_E_INVALIDPROJECTHANDLE;

   // Results computed with the old model must not be combined with the new
   // model's component vector, so drop them together with the old selection.
   if (m_hPrediction)
      SQ_ClearPrediction(&m_hPrediction);
   m_hPrediction = nullptr;
   m_hModel = nullptr;

   //////////////////////////////////////////////////////////////////////////
   // Check that the model exists in SIMCAQ
   int iNumModels = 0;
   ReturnOnError(SQ_GetNumberOfModels(m_hProject, &iNumModels));
   if (iModelIndex < 1 || iModelIndex > iNumModels)
      return SQ_E_INVALIDINDEX;

   // Get the model number connected with the input index.
   int iModelNumber = -1;
   ReturnOnError(SQ_GetModelNumberFromIndex(m_hProject, iModelIndex, &iModelNumber));

   // Resolve into a local first; m_hModel is only committed at the end.
   decltype(m_hModel) hModel = nullptr;
   ReturnOnError(SQ_GetModel(m_hProject, iModelNumber, &hModel));

   // Only a fitted model can be used; a failing query is treated as "not fitted".
   SQ_Bool bIsFitted;
   if (SQ_IsModelFitted(hModel, &bIsFitted) != SQ_E_OK || bIsFitted != SQ_True)
      return SQ_E_MODELNOTFITTED;

   //////////////////////////////////////////////////////////////////////////
   // Get results
   ReturnOnError(SQ_GetNumberOfPredictiveComponents(hModel, &m_iComponents));
   // Component vector with a single entry: the model's last predictive component.
   // NOTE(review): re-initializes m_oComponents on every call; confirm in EzQ.h
   // that a previously initialized vector is released.
   ReturnOnError(SQ_InitIntVector(m_oComponents.get(), 1));
   ReturnOnError(SQ_SetDataInIntVector(m_oComponents, 1, m_iComponents));
   // Probability level used for retrieval of the critical levels.
   ReturnOnError(SQ_GetDefaultProbabilityLevel(hModel, &m_fProbLevel));
   // DModX critical limit.
   ReturnOnError(SQ_GetDModXCrit(hModel, m_iComponents, SQ_Normalized_True, m_fProbLevel, &m_fDModXCrit));
   // Hotelling T2Range critical limit.
   ReturnOnError(SQ_GetT2RangeCrit(hModel, 1, -1, m_fProbLevel, &m_fHotT2Crit));

   m_hModel = hModel;
   return SQ_E_OK;
}

//////////////////////////////////////////////////////////////////////////
// Predict with quantitative data whose columns are matched to the model's
// prediction variables by name.
//
// fQuantitativeMatrix is read row-major: nQuantRows rows of nQuantCols values,
// where column iCol is named szColumnNames[iCol]. The column order does not have
// to match the model. If fQuantitativeMatrix is nullptr, no data is set before
// predicting.
//
// Returns:
//   SQ_E_INVALIDMODELHANDLE  no model is selected
//   SQ_E_INVALIDINDEX        column names are missing, or a model variable has
//                            no column with the same name
//   SQ_E_ERROR_LAST          a matched variable is qualitative (not supported here)
//   otherwise the first failing SIMCA-Q error code, or SQ_E_OK.
// On success the prediction is kept in m_hPrediction for the Get*PS getters.
// The previous prediction is always released first.
SQ_ErrorCode CEzQ::Predict(const float* fQuantitativeMatrix, int nQuantRows, int nQuantCols, const char* const* szColumnNames)
{
   assert(m_hModel);
   if (!m_hModel)
      return SQ_E_INVALIDMODELHANDLE;

   // SQ_GetPrediction below would otherwise overwrite (and leak) the old handle,
   // and a failed call must not leave stale results that look current.
   if (m_hPrediction)
      SQ_ClearPrediction(&m_hPrediction);
   m_hPrediction = nullptr;

   // Start preparing the predictions
   CEzPreparePrediction hPrepPred;
   ReturnOnError(SQ_GetPreparePrediction(m_hModel, hPrepPred));

   // Variables the model needs for the prediction
   SQ_VariableVector hVarVec = nullptr;
   int numPredsetVars = 0;
   ReturnOnError(SQ_GetVariablesForPrediction(hPrepPred, &hVarVec));
   ReturnOnError(SQ_GetNumVariablesInVector(hVarVec, &numPredsetVars));

   // Set the data for the prediction, aligning input columns to model variables by name.
   if (fQuantitativeMatrix != nullptr)
   {
      if (nQuantCols > 0 && szColumnNames == nullptr)
         return SQ_E_INVALIDINDEX;

      //TODO: cache this if you can and know that the data is always in the same order, alternatively create a lookup of "varVec".
      std::map<std::string, int> DataLookup;
      for (int iCol = 0; iCol < nQuantCols; ++iCol)
         DataLookup[szColumnNames[iCol]] = iCol;

      for (int iVar = 1; iVar <= numPredsetVars; ++iVar)
      {
         SQ_Variable hVariable = nullptr;
         ReturnOnError(SQ_GetVariableFromVector(hVarVec, iVar, &hVariable));

         char szName[100];
         ReturnOnError(SQ_GetVariableName(hVariable, 1, szName, sizeof(szName)));

         const auto itColumn = DataLookup.find(szName);
         if (itColumn == DataLookup.end())
         {
            // missing a variable
            // predictions can be calculated anyway if at least a few of the variables exists,
            // change here to continue if you like to predict even if you don't have all the data.
            return SQ_E_INVALIDINDEX;
         }
         const int iDataCol = itColumn->second;

         SQ_Bool bQualitative;
         ReturnOnError(SQ_IsQualitative(hVariable, &bQualitative));
         if (bQualitative)
         {
            // not implemented in this sample
            return SQ_E_ERROR_LAST;
         }

         for (int iRow = 1; iRow <= nQuantRows; ++iRow)
            ReturnOnError(SQ_SetQuantitativeData(hPrepPred, iRow, iVar, fQuantitativeMatrix[(iRow - 1) * nQuantCols + iDataCol]));
      }
   }

   // Predict
   ReturnOnError(SQ_GetPrediction(hPrepPred, &m_hPrediction));

   return SQ_E_OK;
}
//////////////////////////////////////////////////////////////////////////
// Predict with the given prediction set.
//
// The model's prediction variables are filled in model order.
//   - Quantitative variables consume consecutive columns of fQuantitativeMatrix
//     (row-major, nQuantRows x nQuantCols).
//   - Qualitative variables consume consecutive columns of szQualitativeMatrix
//     (row-major, nQuanlRows x nQuanlCols).
// Either matrix may be nullptr. nQuantCols + nQuanlCols must equal the number of
// prediction variables. The row counts only have to agree when both matrices
// are supplied.
//
// Returns SQ_E_INVALIDMODELHANDLE if no model is selected, SQ_E_INVALIDINDEX on
// a row/column mismatch, otherwise the first failing SIMCA-Q error code or SQ_E_OK.
// On success the prediction is kept in m_hPrediction. The previous prediction is
// always released first.
SQ_ErrorCode CEzQ::Predict2(const float* fQuantitativeMatrix, int nQuantRows, int nQuantCols, const char* const* szQualitativeMatrix, int nQuanlRows, int nQuanlCols)
{
   assert(m_hModel);
   if (!m_hModel)
      return SQ_E_INVALIDMODELHANDLE;

   const bool bHasQuant = fQuantitativeMatrix != nullptr;
   const bool bHasQual = szQualitativeMatrix != nullptr;

   if (bHasQuant && bHasQual && nQuantRows != nQuanlRows)
      return SQ_E_INVALIDINDEX;

   // SQ_GetPrediction below would otherwise overwrite (and leak) the old handle.
   if (m_hPrediction)
      SQ_ClearPrediction(&m_hPrediction);
   m_hPrediction = nullptr;

   // Start preparing the predictions
   CEzPreparePrediction hPrepPred;
   ReturnOnError(SQ_GetPreparePrediction(m_hModel, hPrepPred));

   // Variables the model needs for the prediction
   SQ_VariableVector hVarVec = nullptr;
   int numPredsetVars = 0;
   ReturnOnError(SQ_GetVariablesForPrediction(hPrepPred, &hVarVec));
   ReturnOnError(SQ_GetNumVariablesInVector(hVarVec, &numPredsetVars));

   if (numPredsetVars != nQuantCols + nQuanlCols) // Must be same number of variables
      return SQ_E_INVALIDINDEX;

   if (bHasQuant && bHasQual)
   {
      // Next unused column of each input matrix. These advance once per
      // variable, not once per row, and persist across variables.
      int iQuantPos = 0;
      int iQualPos = 0;
      for (int iVar = 1; iVar <= numPredsetVars; ++iVar)
      {
         SQ_Variable hVariable = nullptr;
         ReturnOnError(SQ_GetVariableFromVector(hVarVec, iVar, &hVariable));

         SQ_Bool bQualitative;
         ReturnOnError(SQ_IsQualitative(hVariable, &bQualitative));

         if (bQualitative)
         {
            if (iQualPos >= nQuanlCols)
               return SQ_E_INVALIDINDEX;   // more qualitative variables than qualitative columns
            for (int iRow = 1; iRow <= nQuantRows; ++iRow)
               ReturnOnError(SQ_SetQualitativeData(hPrepPred, iRow, iVar, szQualitativeMatrix[(iRow - 1) * nQuanlCols + iQualPos]));
            ++iQualPos;
         }
         else
         {
            if (iQuantPos >= nQuantCols)
               return SQ_E_INVALIDINDEX;   // more quantitative variables than quantitative columns
            for (int iRow = 1; iRow <= nQuantRows; ++iRow)
               ReturnOnError(SQ_SetQuantitativeData(hPrepPred, iRow, iVar, fQuantitativeMatrix[(iRow - 1) * nQuantCols + iQuantPos]));
            ++iQuantPos;
         }
      }
   }
   else if (bHasQuant) // Only quantitative data
   {
      if (nQuantCols != numPredsetVars)
         return SQ_E_INVALIDINDEX;
      ReturnOnError(SQ_SetQuantitativeDataRaw(hPrepPred, nQuantRows, fQuantitativeMatrix));
   }
   else if (bHasQual) // Only qualitative data
   {
      if (nQuanlCols != numPredsetVars)
         return SQ_E_INVALIDINDEX;

      SQ_StringMatrix oMatrix = nullptr;
      ReturnOnError(SQ_InitStringMatrix(&oMatrix, nQuanlRows, nQuanlCols));

      // The temporary matrix is released whatever happens while handing it over.
      SQ_ErrorCode eErr = SQ_SetStringMatrix(oMatrix, szQualitativeMatrix);
      if (eErr == SQ_E_OK)
         eErr = SQ_SetQualitativeDataMatrix(hPrepPred, oMatrix);
      SQ_ClearStringMatrix(&oMatrix);
      if (eErr != SQ_E_OK)
         return eErr;
   }

   // Predict
   ReturnOnError(SQ_GetPrediction(hPrepPred, &m_hPrediction));

   return SQ_E_OK;
}

//////////////////////////////////////////////////////////////////////////
// Releases the current prediction and project (when present) and resets the
// model handle and every cached model value to its "unset" state.
// Safe to call when nothing is open; always returns SQ_E_OK.
SQ_ErrorCode CEzQ::CloseProject()
{
   if (m_hPrediction)
   {
      SQ_ClearPrediction(&m_hPrediction);
   }
   if (m_hProject)
   {
      SQ_CloseProject(&m_hProject);
   }

   // Reset handles and cached values unconditionally, independent of what the
   // release calls above left behind.
   m_hPrediction = nullptr;
   m_hProject = nullptr;
   m_hModel = nullptr;

   m_iComponents = -1;
   m_fProbLevel = -1;
   m_fDModXCrit = -1;
   m_fHotT2Crit = -1;

   return SQ_E_OK;
}

//////////////////////////////////////////////////////////////////////////
// Fills fmatResult with the predicted scores (T) of the current prediction,
// for the components listed in m_oComponents (set up by SetModel).
// Returns the SQ_GetTPS error code.
SQ_ErrorCode CEzQ::GetTPS(CEzVectorData& fmatResult)
{
   assert(m_hProject && m_hPrediction);

   const SQ_ErrorCode eResult = SQ_GetTPS(m_hPrediction, m_oComponents.get(), fmatResult.get());
   return eResult;
}

//////////////////////////////////////////////////////////////////////////
// Fills fmatResult with the predicted DModX of the current prediction, for the
// components in m_oComponents, requested with SQ_Normalized_True and
// SQ_ModelingPowerWeighted_False. Returns the SQ_GetDModXPS error code.
SQ_ErrorCode CEzQ::GetDModXPS(CEzVectorData& fmatResult)
{
   assert(m_hProject && m_hPrediction);

   const SQ_ErrorCode eResult = SQ_GetDModXPS(m_hPrediction, m_oComponents.get(), SQ_Normalized_True, SQ_ModelingPowerWeighted_False, fmatResult.get());
   return eResult;
}

//////////////////////////////////////////////////////////////////////////
// Fills fmatResult with the predicted DModY of the current prediction, for the
// components in m_oComponents, requested with SQ_Normalized_True.
// Returns the SQ_GetDModYPS error code.
SQ_ErrorCode CEzQ::GetDModYPS(CEzVectorData& fmatResult)
{
   assert(m_hProject && m_hPrediction);

   const SQ_ErrorCode eResult = SQ_GetDModYPS(m_hPrediction, m_oComponents.get(), SQ_Normalized_True, fmatResult.get());
   return eResult;
}

//////////////////////////////////////////////////////////////////////////
// Fills fmatResult with the predicted Y values of the current prediction using
// m_iComponents components, requested with SQ_Unscaled_False and
// SQ_Backtransformed_True. nullptr is passed for the variable-selection argument
// (NOTE(review): presumably "all Y variables" — confirm in SIMCA-Q docs).
// Returns the SQ_GetYPredPS error code.
SQ_ErrorCode CEzQ::GetYPredPS(CEzVectorData& fmatResult)
{
   assert(m_hProject && m_hPrediction);

   const SQ_ErrorCode eResult = SQ_GetYPredPS(m_hPrediction, m_iComponents, SQ_Unscaled_False, SQ_Backtransformed_True, nullptr, fmatResult.get());
   return eResult;
}

//////////////////////////////////////////////////////////////////////////
// Fills fmatResult with the predicted Hotelling T2Range of the current
// prediction over components 1..m_iComponents.
// Returns the SQ_GetT2RangePS error code.
SQ_ErrorCode CEzQ::GetT2RangePS(CEzVectorData& fmatResult)
{
   assert(m_hProject && m_hPrediction);

   const SQ_ErrorCode eResult = SQ_GetT2RangePS(m_hPrediction, 1, m_iComponents, fmatResult.get());
   return eResult;
}

//////////////////////////////////////////////////////////////////////////
// Returns the T2RangeCrit from the model.
// Writes into fResult the Hotelling T2Range critical limit of the selected model
// over components 1..m_iComponents. -1 is passed as the probability-level
// argument (NOTE(review): presumably "use the default level" — confirm).
// Returns SQ_E_INVALIDMODELHANDLE when no model is selected (the handle actually
// used here), otherwise the SQ_GetT2RangeCrit error code.
SQ_ErrorCode CEzQ::GetT2RangeCrit(float& fResult)
{
   assert(m_hModel);
   if (!m_hModel)
      return SQ_E_INVALIDMODELHANDLE;

   return SQ_GetT2RangeCrit(m_hModel, 1, m_iComponents, -1, &fResult);
}

//////////////////////////////////////////////////////////////////////////
// Return the description of an error.
// Returns SIMCA-Q's text for eError, truncated to the local buffer size. If
// SIMCA-Q writes nothing, the result is an empty string.
std::string CEzQ::GetErrorDescription(SQ_ErrorCode eError) const
{
   // Zero-initialized so the buffer is a valid (empty) C string even if the
   // lookup fails without writing to it.
   char szError[256] = {};
   SQ_GetErrorDescription(eError, szError, sizeof(szError));
   // Never read past the buffer, even if the text filled it without a terminator.
   szError[sizeof(szError) - 1] = '\0';
   return std::string(szError);
}
